import re, os
from scapy.all import *


class Flow:
    """One-way UDP flow: all UDP packets that share the same 4-tuple.

    The 4-tuple is (source IP, destination IP, source port, destination
    port); the flow is unidirectional. Only packet timestamps are kept,
    buffered in ``time_stamp_list`` until drained by ``get_last_time_stamp``.
    """

    def __init__(self, ip_src, ip_dst, port_src, port_dst):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.port_src, self.port_dst = port_src, port_dst
        self.time_stamp_list = []

    def append_time_stamp(self, time_stamp):
        """Buffer the timestamp of one more packet seen on this flow."""
        self.time_stamp_list.append(time_stamp)

    def get_last_time_stamp(self):
        """Return the two newest buffered timestamps and reset the buffer.

        The original note "BAAAAAB: get last A timestamp" suggests the intent:
        of the run of timestamps buffered between two other events (the B's),
        pick the most recent one (the last A).

        Returns:
            ``(newest, second_newest)``. When only one timestamp is buffered,
            both elements are that timestamp.

        Raises:
            IndexError: if no timestamp is buffered (the buffer is left as is).
        """
        stamps = self.time_stamp_list
        newest = stamps[-1]
        second_newest = stamps[-2] if len(stamps) >= 2 else newest
        # Rebind rather than clear in place, so the caller-visible list object
        # that was drained is not mutated.
        self.time_stamp_list = []
        return newest, second_newest
    


class BasicUdp:
    """Collect reply-interval samples for UDP traffic around one target host.

    Packets sent *to* ``target_ip`` (A -> N) are grouped into one-way ``Flow``
    objects keyed by 4-tuple; only their timestamps are buffered. When a
    packet comes *from* ``target_ip`` (N -> A) on a tracked 4-tuple, the time
    since the buffered A -> N packet(s) becomes one sample appended to
    ``tints``. ``get_tints`` drains and filters those samples.

    Samples are timestamp differences multiplied by 1000 (ms, presumably,
    if ``ts`` is in seconds — confirm with the caller).
    """

    # Upper bound for samples: a latest-packet interval strictly above this is
    # recorded as the 0.0 placeholder; get_tints keeps only values below it.
    _MAX_TINT = 21000.0
    # A latest-packet interval below this is treated as too short; the
    # interval measured from the second-latest buffered packet is tried.
    _MIN_TINT1 = 1.1
    # The second-latest-packet interval must be strictly above this to be used.
    _MIN_TINT2 = 1.0

    def __init__(self, target_ip):
        # One-way A -> N flows only, keyed "<N ip>-<A ip>-<N port>-<A port>".
        self.flows_dict = {}
        # Interval samples; 0.0 is a placeholder that get_tints drops.
        self.tints = []
        self.target_ip = target_ip

    def add_packet(self, pk: "Packet"):
        """Feed one packet record into the tracker.

        A -> N packets only buffer their timestamp on their flow. An N -> A
        packet whose reversed 4-tuple has buffered timestamps drains that
        buffer and appends exactly one sample to ``self.tints``. All other
        packets are ignored.

        Args:
            pk: record exposing ``ts``, ``src_ip``, ``dst_ip``, ``src_port``
                and ``dst_port``. NOTE(review): a stock scapy ``Packet`` does
                not expose these attribute names; presumably this is a
                pre-parsed record — confirm against the caller. The
                annotation is a string so it is not evaluated at def time.

        NOTE(review): nothing in this class removes entries from
        ``flows_dict``, so it grows with the number of distinct 4-tuples.
        """
        time_stamp, ip_src, ip_dst, port_src, port_dst = (
            pk.ts, pk.src_ip, pk.dst_ip, pk.src_port, pk.dst_port)

        if ip_dst == self.target_ip:  # A -> N
            key = '{}-{}-{}-{}'.format(ip_dst, ip_src, port_dst, port_src)
            flow = self.flows_dict.get(key)
            if flow is None:
                flow = Flow(ip_src, ip_dst, port_src, port_dst)
                self.flows_dict[key] = flow
            flow.append_time_stamp(time_stamp)
        elif ip_src == self.target_ip:  # N -> A
            # Same key layout as the A -> N branch, built from the reply side.
            reverse_key = '{}-{}-{}-{}'.format(ip_src, ip_dst, port_src, port_dst)
            flow = self.flows_dict.get(reverse_key)
            if flow is None or not flow.time_stamp_list:
                return
            # Newest and second-newest buffered A -> N timestamps; this call
            # also empties the flow's buffer.
            last_time_stamp, last_time_stamp_1 = flow.get_last_time_stamp()
            tint1 = (float(time_stamp) - float(last_time_stamp)) * 1000
            tint2 = (float(time_stamp) - float(last_time_stamp_1)) * 1000
            self.tints.append(self._select_interval(tint1, tint2))

    def _select_interval(self, tint1, tint2):
        """Choose the sample to record for one reply.

        Args:
            tint1: interval from the newest buffered A -> N packet.
            tint2: interval from the second-newest one (equals ``tint1`` when
                only one packet was buffered).

        Returns:
            The sample to append. When no usable interval exists, this is the
            previous sample, or a placeholder if there is none yet.
        """
        if tint1 > self._MAX_TINT:
            return 0.0
        if tint1 < self._MIN_TINT1:
            if self._MIN_TINT2 < tint2 < self._MAX_TINT:
                return tint2
            return self.tints[-1] if self.tints else 0.0
        # NOTE(review): tint1 is in range here, yet the previous sample is
        # repeated whenever one exists, so after the first sample only the
        # tint2 branch ever introduces a new value. This may have been meant
        # as `return tint1`; behavior kept as-is pending confirmation.
        return self.tints[-1] if self.tints else tint1

    def get_tints(self) -> list:
        """Return the usable samples collected so far and reset the buffer.

        Placeholders (0.0) and samples at or above ``_MAX_TINT`` are dropped.

        The original also had a fallback re-filter (99.9 < x < 21000.0) for
        the case where this filter kept nothing. That branch could never
        return anything: in that case every sample is 0.0 or >= 21000.0, so
        it was removed.
        """
        samples, self.tints = self.tints, []
        return [x for x in samples if x != 0.0 and x < self._MAX_TINT]
    
